{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e407034e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7092316d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_9892\\3970781211.py\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m      5\u001b[0m \u001b[0mbatch_size\u001b[0m \u001b[1;33m=\u001b[0m \u001b[1;36m4\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      6\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 7\u001b[1;33m trainset = torchvision.datasets.CIFAR100(root='./data/cifar100', train=True,\n\u001b[0m\u001b[0;32m      8\u001b[0m                                         download=True, transform=transform)\n\u001b[0;32m      9\u001b[0m trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\cifar.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, root, train, transform, target_transform, download)\u001b[0m\n\u001b[0;32m     65\u001b[0m             \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdownload\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     66\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 67\u001b[1;33m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_check_integrity\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     68\u001b[0m             \u001b[1;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"Dataset not found or corrupted. You can use download=True to download it\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     69\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\cifar.py\u001b[0m in \u001b[0;36m_check_integrity\u001b[1;34m(self)\u001b[0m\n\u001b[0;32m    129\u001b[0m         \u001b[1;32mfor\u001b[0m \u001b[0mfilename\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmd5\u001b[0m \u001b[1;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtrain_list\u001b[0m \u001b[1;33m+\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtest_list\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    130\u001b[0m             \u001b[0mfpath\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mroot\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mbase_folder\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfilename\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 131\u001b[1;33m             \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mcheck_integrity\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfpath\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m    132\u001b[0m                 \u001b[1;32mreturn\u001b[0m \u001b[1;32mFalse\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m    133\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\utils.py\u001b[0m in \u001b[0;36mcheck_integrity\u001b[1;34m(fpath, md5)\u001b[0m\n\u001b[0;32m     68\u001b[0m     \u001b[1;32mif\u001b[0m \u001b[0mmd5\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     69\u001b[0m         \u001b[1;32mreturn\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 70\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mcheck_md5\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfpath\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     71\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     72\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\utils.py\u001b[0m in \u001b[0;36mcheck_md5\u001b[1;34m(fpath, md5, **kwargs)\u001b[0m\n\u001b[0;32m     60\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     61\u001b[0m \u001b[1;32mdef\u001b[0m \u001b[0mcheck_md5\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfpath\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmd5\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mstr\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mbool\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 62\u001b[1;33m     \u001b[1;32mreturn\u001b[0m \u001b[0mmd5\u001b[0m \u001b[1;33m==\u001b[0m \u001b[0mcalculate_md5\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfpath\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     63\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     64\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32mD:\\ProgramData\\Anaconda3\\lib\\site-packages\\torchvision\\datasets\\utils.py\u001b[0m in \u001b[0;36mcalculate_md5\u001b[1;34m(fpath, chunk_size)\u001b[0m\n\u001b[0;32m     55\u001b[0m     \u001b[1;32mwith\u001b[0m \u001b[0mopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfpath\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"rb\"\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mf\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     56\u001b[0m         \u001b[1;32mwhile\u001b[0m \u001b[0mchunk\u001b[0m \u001b[1;33m:=\u001b[0m \u001b[0mf\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mread\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mchunk_size\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 57\u001b[1;33m             \u001b[0mmd5\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mchunk\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     58\u001b[0m     \u001b[1;32mreturn\u001b[0m \u001b[0mmd5\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mhexdigest\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     59\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "batch_size = 4\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR100(root='./data/cifar100', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR100(root='./data/cifar100', train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "\n",
    "CIFAR100_class = 100  # 数据集的分类类别数量"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c68cc8fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = ('apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'computer_keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3f84f90",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# functions to show an image\n",
    "\n",
    "\n",
    "def imshow(img):\n",
    "    img = img / 2 + 0.5     # unnormalize\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# get some random training images\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "# show images\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# print labels\n",
    "print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "089b10b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# 定义Residual Block\n",
    "class ResidualBlock(nn.Module):\n",
    "    def __init__(self, in_channels, out_channels, stride=1):\n",
    "        super(ResidualBlock, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(out_channels)\n",
    "        self.relu1 = nn.ReLU()\n",
    "        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)\n",
    "        self.bn2 = nn.BatchNorm2d(out_channels)\n",
    "        self.relu2 = nn.ReLU()\n",
    "        \n",
    "        self.shortcut = nn.Sequential()\n",
    "        if stride != 1 or in_channels != out_channels:\n",
    "            self.shortcut = nn.Sequential(\n",
    "                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),\n",
    "                nn.BatchNorm2d(out_channels)\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.bn2(self.conv2(out))\n",
    "        out += self.shortcut(x)\n",
    "        out = F.relu(out)\n",
    "        return out\n",
    "\n",
    "# 定义ResNet\n",
    "class ResNet(nn.Module):\n",
    "    def __init__(self, block, layers, num_classes=100):\n",
    "        super(ResNet, self).__init__()\n",
    "        self.in_channels = 32\n",
    "        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(32)\n",
    "        self.layer1 = self.make_layer(block, 32, 32, 2, stride=1)\n",
    "        self.layer2 = self.make_layer(block, 32, 64, 2, stride=2)\n",
    "        self.layer3 = self.make_layer(block, 64, 64, 2, stride=1)\n",
    "        self.layer4 = self.make_layer(block, 64, 128, 2, stride=2)\n",
    "        self.layer5 = self.make_layer(block, 128, 128, 2, stride=1)\n",
    "        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))\n",
    "        #self.fc = nn.Linear(4096, 1024)\n",
    "        #self.fc2 = nn.Linear(1024, num_classes)\n",
    "        self.fc = nn.Linear(128, num_classes)\n",
    "        self.relu = nn.ReLU()\n",
    "\n",
    "    def make_layer(self, block, in_channels, out_channels, blocks, stride):\n",
    "        layers = []\n",
    "        layers.append(block(in_channels, out_channels, stride))\n",
    "        self.in_channels = out_channels\n",
    "        for i in range(1, blocks):\n",
    "            layers.append(block(out_channels, out_channels, stride=1))\n",
    "        return nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        out = F.relu(self.bn1(self.conv1(x)))\n",
    "        out = self.layer1(out)\n",
    "        out = self.layer2(out)\n",
    "        out = self.layer3(out)\n",
    "        out = self.layer4(out)\n",
    "        #out = self.layer4(out)\n",
    "        out = self.avg_pool(out)\n",
    "        out = out.view(out.size(0), -1)\n",
    "        out = self.fc(out)\n",
    "        #out = self.fc2(out)\n",
    "        out = self.relu(out)\n",
    "        return out\n",
    "    \n",
    "def resnet18(num_classes=100):\n",
    "    return ResNet(ResidualBlock, [2, 2, 2, 2])\n",
    " \n",
    "\n",
    "def resnet50(num_classes=100):\n",
    "    return ResNet(ResidualBlock, [3, 4, 6, 3])\n",
    "\n",
    "def res2net101(num_classes=100):\n",
    "    return ResNet(ResidualBlock, [3, 4, 23, 3])\n",
    "    \n",
    "net = resnet18()\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "net.to(device)\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "22bc09ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "#optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=0.0004)\n",
    "PATH = './cifar100_resnet.pth'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e608592c",
   "metadata": {},
   "outputs": [],
   "source": [
    "lossv, accv = [], []\n",
    "for epoch in range(30):  # loop over the dataset multiple times\n",
    "    \n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "        # get the inputs; data is a list of [inputs, labels]\n",
    "        inputs, labels = data\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # print statistics\n",
    "        running_loss += loss.item()\n",
    "\n",
    "        if i % 2000 == 1999:    # print every 2000 mini-batches\n",
    "            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')\n",
    "            lossv.append(running_loss / 2000)\n",
    "            running_loss = 0.0\n",
    "            torch.save(net.state_dict(), PATH)\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea34d50d",
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH = './cifar100_resnet.pth'\n",
    "torch.save(net.state_dict(), PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2ba66156",
   "metadata": {},
   "outputs": [],
   "source": [
    "net=resnet18()\n",
    "net.load_state_dict(torch.load(PATH))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6fb44ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataiter = iter(testloader)\n",
    "images, labels = next(dataiter)\n",
    "\n",
    "# print images\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f75a97d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = net(images)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6302520e",
   "metadata": {},
   "outputs": [],
   "source": [
    "_, predicted = torch.max(outputs, 1)\n",
    "\n",
    "print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'\n",
    "                              for j in range(4)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5cc23b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "# since we're not training, we don't need to calculate the gradients for our outputs\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        # calculate outputs by running images through the network\n",
    "        outputs = net(images)\n",
    "        # the class with the highest energy is what we choose as prediction\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "        accv.append(100 * correct // total)\n",
    "\n",
    "\n",
    "print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')\n",
    "print(correct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c6abb2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare to count predictions for each class\n",
    "correct_pred = {classname: 0 for classname in classes}\n",
    "total_pred = {classname: 0 for classname in classes}\n",
    "\n",
    "# again no gradients needed\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)\n",
    "        _, predictions = torch.max(outputs, 1)\n",
    "        # collect the correct predictions for each class\n",
    "        for label, prediction in zip(labels, predictions):\n",
    "            if label == prediction:\n",
    "                correct_pred[classes[label]] += 1\n",
    "            total_pred[classes[label]] += 1\n",
    "\n",
    "\n",
    "# print accuracy for each class\n",
    "for classname, correct_count in correct_pred.items():\n",
    "    accuracy = 100 * float(correct_count) / total_pred[classname]\n",
    "    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f768da",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(5,3))\n",
    "plt.plot(np.linspace(1, 15, len(lossv)), lossv)\n",
    "plt.title('validation loss')\n",
    "\n",
    "plt.figure(figsize=(5,3))\n",
    "plt.plot(np.linspace(1, 15, len(accv)), accv)\n",
    "plt.title('validation accuracy')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
